import torch

def main():
    """Compare ``roi_align`` and ``roi_pool`` on a tiny, hand-checkable input.

    Builds a (1, 1, 5, 5) float32 feature map holding the values 0..24 in
    row-major order. It then pools a single box with corners (0, 0) and
    (1, 1) to a 4x4 grid with both ``roi_align`` and ``roi_pool``. Finally it
    prints the feature map followed by each pooled result.
    """
    # Imported here rather than at module level so that importing this
    # module does not require torchvision.
    from torchvision.ops import roi_align, roi_pool

    height, width = 5, 5
    output_size = (4, 4)

    # A deterministic ramp 0, 1, ..., 24 makes the pooled values easy to
    # check by hand. view() fixes the layout to (N=1, C=1, H, W).
    fp = torch.arange(height * width, dtype=torch.float32).view(1, 1, height, width)
    print(fp)

    # One RoI per row: [batch_id, x1, y1, x2, y2].
    boxes = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.float32)

    # roi_align is called with its defaults. Per the torchvision docs these
    # are spatial_scale=1.0, sampling_ratio=-1 and aligned=False. The two ops
    # treat box corners differently, so their outputs are not expected to
    # match.
    pooled_features = roi_align(fp, boxes, output_size)
    print(pooled_features)

    pooled_features = roi_pool(fp, boxes, output_size)
    print(pooled_features)


if __name__ == '__main__':
    main()
